[MXNet-1334][Fit API]base class for estimator and eventhandler #14346
Conversation
@mxnet-label-bot add [Gluon, pr-work-in-progress]
    metric.reset()

    for i, batch in enumerate(train_data):
        data, label = self._batch_fn(batch, self.context)
Hi @roywei, in NVIDIA DALI we return already-sliced data; could you make _batch_fn optional?
Hi @ptrendx, sounds good! Any examples on multi-GPU? Would love to check it out.
        self.loss = [loss]
    else:
        self.loss = loss or []
    if isinstance(metrics, EvalMetric):
This does not work for a list of metrics:
>>> class A:
...     def __init__(self):
...         pass
...
>>> A
<class __main__.A at 0x721d0fe88>
>>> x = [A(), A()]
>>> isinstance(x, A)
False
Yes, it accepts a single metric object or a list of metrics (handled in the else branch).
Regardless of whether a single metric or a list of them is passed, we should have this validation.
@nswamy updated
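For reference, the normalize-then-validate pattern discussed in this thread can be sketched as follows. The `EvalMetric` class below is a stand-in defined only so the sketch is self-contained, not the real `mxnet.metric.EvalMetric`:

```python
class EvalMetric:
    """Stand-in for mxnet.metric.EvalMetric (illustration only)."""
    def __init__(self, name):
        self.name = name


def normalize_metrics(metrics):
    """Accept a single metric, a list of metrics, or None; return a list."""
    if isinstance(metrics, EvalMetric):
        metrics = [metrics]
    else:
        metrics = metrics or []
    # Validate every element, regardless of how the argument was passed,
    # since isinstance(list_of_metrics, EvalMetric) is always False.
    for m in metrics:
        if not isinstance(m, EvalMetric):
            raise ValueError("metrics must be EvalMetric instances, got %s" % type(m))
    return metrics
```

This way the single-object, list, and None cases all end up as a validated list.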
        self.context = [cpu()]

    # initialize the network
    if self.initializer:
Suggest moving the initialization and the check for it into one method.
?
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        return data, label

    def fit(self, train_data,
Nit: train_data -> train_dataloader, etc. If valid_data is not passed, are you going to split the train data by some percentage? If not, I think we should.
Currently a DataLoader does not support splitting into train and val after it's created. Users can create that split outside the estimator using a random sampler, which is the same behavior as PyTorch. Keras allows it, but only because its fit method accepts NumPy arrays directly, and it can be slow.
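Since the split has to happen before the DataLoaders are built, the user-side index splitting could be sketched like this (pure Python; the `val_ratio` value and the sampler/DataLoader wiring mentioned in the comment are illustrative assumptions):

```python
import random


def split_indices(num_samples, val_ratio=0.1, seed=0):
    """Shuffle sample indices and split them into train/val index lists."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    num_val = int(num_samples * val_ratio)
    return indices[num_val:], indices[:num_val]


# The two index lists would then back two samplers, each feeding its own
# DataLoader, e.g. gluon.data.DataLoader(dataset, sampler=...).
train_idx, val_idx = split_indices(1000, val_ratio=0.2)
```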
val_data is missing from the signature? Also, it should not be optional if you expect the user to pass it.
please rename to train_dataloader
            batch_size=None,
            event_handlers=None):

    if not batch_size:
log.info() the batch size you end up using.
Will be added in train_begin in the logging handler, along with the change to train_stats, in a follow-up PR.
        self.logger.addHandler(filehandler)

    def train_begin(self):
        pass
Might be a good idea to log all the hyperparameters and the defaults used here.
Will be added with the change to train_stats in a follow-up PR.
        pass
        # logger.info(opt)

    def train_end(self):
Output an informational message at the end: number of epochs run, train_loss and valid_loss, other metrics, etc.
Will be added with the change to train_stats in a follow-up PR.
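As a sketch of what such a train_end summary could look like: the shape of `train_stats` below (an `'epochs'` list plus `'train_*'`/`'val_*'` metric lists) is an assumption based on the snippets in this conversation, not the final API:

```python
class SummaryHandler:
    """Hypothetical handler that summarizes a finished run in train_end."""

    def __init__(self, train_stats):
        # Assumed shape: {'epochs': [...], 'train_accuracy': [...], 'val_accuracy': [...]}
        self.train_stats = train_stats

    def train_end(self):
        lines = ['Training finished after %d epoch(s).' % len(self.train_stats['epochs'])]
        # Report the last recorded value of every train/val metric series.
        for key, values in sorted(self.train_stats.items()):
            if key.startswith(('train_', 'val_')) and values:
                lines.append('final %s: %.4f' % (key, values[-1]))
        return '\n'.join(lines)
```

The real handler would presumably push these lines through `self.logger` rather than return a string; returning it here just keeps the sketch testable.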
        self.trainers = [trainers]
    else:
        self.trainers = trainers or []
    if not self.trainers:
Are we handling the multiple-trainer case here (e.g. multi-task classification)?
Will be addressed in another PR (TODO item 4).
        handler.batch_end()

    for metric in self.metrics + self.loss_metrics:
        self.train_stats['train_' + metric.name].append(metric.get()[1])
Validation metrics would also be needed in train_stats. How are we dealing with that?
It's tracked here; will be addressed in a follow-up PR.
@nswamy @ankkhedia I have addressed the comments and added checkpoint and early-stopping handlers; could you take another look? Thanks. The remaining TODO items from your comments will be addressed in follow-up PRs.
@mxnet-label-bot update [Gluon, pr-awaiting-review]
    elif isinstance(train_data, DataIter):
        data, label = self._batch_fn(batch, self.context, is_iterator=True)
    else:
        raise ValueError("You are using a custom iteration, please also provide "
Nit: "custom iteration" should read "custom iterator".
    early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
                                                         patience=patience,
                                                         mode=mode)]
    est.fit(test_data, event_handlers=early_stopping, epochs=1)
No assertions here?
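One property such a test could assert on is the epoch at which training stops. A minimal sketch of patience-based stopping (not the real EarlyStoppingHandler; it assumes a "max" monitoring mode where larger values are better):

```python
def early_stop_epoch(history, patience=2):
    """Return the epoch at which monitoring would stop, or None.

    Stops once the monitored value has failed to improve (here: increase)
    for `patience` consecutive epochs.
    """
    best, waited = float('-inf'), 0
    for epoch, value in enumerate(history):
        if value > best:
            best, waited = value, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None
```

A test could then feed a plateauing metric history and assert that the stop fires at the expected epoch, and that a strictly improving history never stops.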
    epoch = self._estimator.train_stats['epochs'][-1]
    msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
    for key in self._estimator.train_stats.keys():
        if key.startswith('train_') or key.startswith('test_'):
Should the key be val_ here?
    if isinstance(metrics, EvalMetric):
        self.metrics = [metrics]
    else:
        self.metrics = metrics or []
Do we allow anything but EvalMetric? If not, please validate that each metric in the list is an EvalMetric.
    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
    logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)]
    est.fit(test_data, event_handlers=logging_handler, epochs=1)
    assert os.path.isfile(output_dir)
Do you want to add a sanity check that the validation accuracy for the last epoch is logged (and that the training process completed)?
    batch_time = time.time() - self.batch_start
    epoch = self._estimator.train_stats['epochs'][-1]
    step = self._estimator.train_stats['step']
    msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time)
I don't think we use the term "step". Let's not introduce new terms; let's continue to use "batch". Look at the existing logging.
        self._estimator.net.save_parameters(self.filepath)


class EarlyStoppingHandler(EventHandler):
Can you please see if we can have an integration test for this?
I will merge this code since it's on a branch; please make sure to address all comments in subsequent PRs before opening a PR to master.
…e#14346) * base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests
* [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346) * base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests * Fixed issue where the estimator was printing beyond the dataset size … (#14464) * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI * [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442) * added estimator unittests * add more tests for estimator * added validation logic * added error handlers, unittests * improve val stats * fix pylint * fix pylint * update unit test * fix tests * fix tests * updated metrics, val logic * trigger ci * trigger ci * update metric, batch_fn error handler * update context logic, add default metric * [MXNet-1340][Fit API]Update train stats (#14494) * add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers * [MXNet-1375][Fit API]Added RNN integration test for fit() API (#14547) * Added RNN integration test for fit() API * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports * CPU test doesn't require nvidiadocker container * Modified the structure by removing the redundant code * [MXNet-1343][Fit API]Add CNN integration test for fit() API (#14405) * added cnn intg tests for fit api * updated cnn intg tests * added functions for nightly test * updated runtime_function * updated intg tests * updated init, datapath, refs * added validation data * update cpu test * refactor code * updated context * [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (#14587) * Retrieve Batch size and Logging verbose support for Gluon fit() API * NIT changes * Addressed review comments: shifted the batch size code to a separate 
method, sentence correction * Modified unittest * removed redundant parameter * Resolve CI test failure * only support DataLoader for now, future PRs will include DataIter to DataLoader converter * Get the number of samples from shape attribute instead of length due to low space complexity * Simplified batch size retrieval code * removed batch_size parameter from fit() method and fixed the tests * Verbose exception handling * Assigning constant to a verbose * Modified exception message * Resolved undefined class reference * Addressed review comments: Modified verbose level names, docs, variable names * Update estimator.py * move estimator to contrib (#14633) * move to gluon contrib (#14635) * [Fit API] improve event handlers (#14685) * improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests * [MXNET-1396][Fit-API] Update default handler logic (#14765) * move to nightly for binaries * update default handler * fix pylint * trigger ci * trigger ci * [Fit API] update estimator (#14849) * address comments * add comment * check available context * fix bug * change cpu check * [Fit-API] Adress PR comments (#14885) * address comments * update checkpoint * test symbol save * address comments * add resume * update doc and resume checkpoint * update docs * trigger ci * trigger ci
Description
This is the first PR for the Gluon Fit API proposed here: Gluon Fit API tech design.
It adds base classes for the estimator and event handlers, plus an example. We would like to merge into the fit-api branch so other contributors can work on top of it before making it available on the master branch.
The JIRA epic can be found here: https://issues.apache.org/jira/browse/MXNET-1333
We are separating this work into a few PRs, all tracked by JIRA issues under the epic.
We will open follow-up PRs for the following tasks:
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
@ankkhedia